[torch.compile] Bunch of small changes needed for enabling torch.compile#3130
[torch.compile] Bunch of small changes needed for enabling torch.compile#3130pggPL wants to merge 5 commits into
Conversation
…stants; fix SP memory leak; test suite hook-up Wrap CommOverlapCore pybind11 methods that return compile-time constants so torch.compile(fullgraph=True) can trace through them without graph breaks: - `is_fp8_ubuf()` → `ub_is_fp8()` / `get_ub_is_fp8()` in base.py; `_ub_is_fp8()` in gemm.py - `with_cublasmp()` → `ub_is_cublasmp()` in base.py All callers in linear.py, layernorm_linear.py, layernorm_mlp.py, base.py, gemm.py, userbuffers_backward_linear.py and userbuffers_forward_linear.py updated. Fix quantized grad_output not being freed early for column-parallel SP backward. Row-parallel SP already called clear_tensor_data(grad_output) to release the gathered tensor; column-parallel SP quantizes grad_output to Float8TensorStorage but never freed it before returning. Under torch.compile reduce-overhead this leaves 3 live pool tensors at recording end and triggers "Detected 3 tensor(s) in the cudagraph pool not tracked as outputs". Extend the existing clear_tensor_data guard to cover both parallel modes. Fix custom-recipe quantizer state being re-initialised on every forward call even when the recipe object has not changed. The existing early-exit for CustomRecipeState was missing an identity check on the recipe object, so any repeated call with the same recipe would bypass the early-return and rebuild quantizers unnecessarily. Add `if recipe_state.recipe is recipe: return` to restore the intended caching behaviour. Add test_torch_compile.py to L0_pytorch_unittest so the autocast and existing compile tests run in CI. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…-accumulator booleans LinearBwdArgs stored the entire FP8 recipe object so the backward could extract fp8_gemm_dgrad.use_split_accumulator and fp8_gemm_wgrad.use_split_accumulator at GEMM time. Recipe objects hold process-group references and are not serialisable as compile-time constants, making them incompatible with torch.compile custom-op paths. Replace fp8_recipe with two plain bool fields: - dgrad_use_split_accumulator (default _2X_ACC_DGRAD) - wgrad_use_split_accumulator (default _2X_ACC_WGRAD) These are resolved once in _linear_setup_ctx and passed into the args struct, so the backward consumes scalars instead of a live recipe object. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
for more information, see https://pre-commit.ci
Greptile SummaryThis PR bundles five targeted
Confidence Score: 5/5All changes are well-scoped and non-breaking; the column-SP FP8 tensor free is correctly placed after the wgrad GEMM and the freed variable is not accessed again. Each change is narrowly targeted: the split-accumulator refactor preserves the same defaults for non-FP8 and reproduces the same hasattr guards for FP8 recipes; the grad_output free only triggers after the GEMM that consumed it; the CustomRecipeState identity check uses No files require special attention. Important Files Changed
Sequence Diagram%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
participant M as Linear.forward()
participant GSM as FP8GlobalStateManager
participant FA as LinearFwdArgs
participant BA as LinearBwdArgs
participant BW as _linear_backward()
M->>GSM: get_fp8_recipe()
GSM-->>M: _recipe
M->>M: resolve dgrad/wgrad split-accumulator bools
M->>FA: fwd_args(dgrad_use_split_accumulator, wgrad_use_split_accumulator)
FA->>BA: _linear_setup_ctx copies plain bools
Note over FA,BA: No recipe object stored
BW->>BW: "use_split_accumulator = bwd_args.dgrad_use_split_accumulator"
BW->>BW: dgrad GEMM
BW->>BW: "use_split_accumulator = bwd_args.wgrad_use_split_accumulator"
BW->>BW: wgrad GEMM
BW->>BW: clear_tensor_data(grad_output) [row-SP or col-SP+FP8]
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
sequenceDiagram
participant M as Linear.forward()
participant GSM as FP8GlobalStateManager
participant FA as LinearFwdArgs
participant BA as LinearBwdArgs
participant BW as _linear_backward()
M->>GSM: get_fp8_recipe()
GSM-->>M: _recipe
M->>M: resolve dgrad/wgrad split-accumulator bools
M->>FA: fwd_args(dgrad_use_split_accumulator, wgrad_use_split_accumulator)
FA->>BA: _linear_setup_ctx copies plain bools
Note over FA,BA: No recipe object stored
BW->>BW: "use_split_accumulator = bwd_args.dgrad_use_split_accumulator"
BW->>BW: dgrad GEMM
BW->>BW: "use_split_accumulator = bwd_args.wgrad_use_split_accumulator"
BW->>BW: wgrad GEMM
BW->>BW: clear_tensor_data(grad_output) [row-SP or col-SP+FP8]
Reviews (3): Last reviewed commit: "Provide explicit QuantizerRoles in torch..." | Re-trigger Greptile |
| @torch.compiler.assume_constant_result | ||
| def get_ub_is_fp8(name: str, use_fp8: bool) -> bool: | ||
| """Query is_fp8_ubuf for a named UB communicator; treated as compile-time constant.""" | ||
| return get_ub(name, use_fp8).is_fp8_ubuf() |
There was a problem hiding this comment.
assume_constant_result can become stale after destroy_ub() + re-init
@torch.compiler.assume_constant_result caches the return value per (name, use_fp8) argument pair for the lifetime of a compiled region. If destroy_ub() is called and UB communicators are re-initialized with different FP8 settings (e.g. in a test harness that re-creates the communicators), the cached is_fp8_ubuf() result would be silently stale until the next recompile. In production training this should not happen — UB is typically initialized once — but test suites that tear down and rebuild UB communicators between cases could observe incorrect fp8_output/fp8_grad flags without triggering a recompile.
|
/te-ci pytorch L1 |
…t_result get_ub_is_fp8 bakes is_fp8_ubuf() as a compile-time constant; without a reset, destroy_ub + re-init with different FP8 settings would read stale values until recompile. Only affects in-memory caches, not disk. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
ToyLinear now overrides get_quantizer_roles so CustomRecipeState doesn't hit the no-roles warning, which graph-breaks under fullgraph=True. qfactory dispatches on role.tensor_type instead of a pre-baked string key. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Description
Small standalone fixes extracted from a larger torch.compile branch, going directly from main. Two independent changes: making Userbuffers pybind11 queries compile-friendly, and freeing quantized grad_output early for column-parallel SP. Plus a custom-recipe caching fix, a split-accumulator refactor, and a CI test hook-up.
Type of change
Changes
Checklist: